import tensorflow as tf
import datetime
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
from flask import Flask, jsonify, request, render_template


# Local directory holding the pretrained GPT-2 large (762M) weights and
# tokenizer files. NOTE(review): hard-coded Windows path — confirm it exists
# on the deployment host.
model_dir = "C:/01.project/chatgpt/gpt-2/models/gpt2-large"
# Load the GPT-2 tokenizer from the local model directory.
tokenizer = GPT2Tokenizer.from_pretrained(model_dir)

# Load the TensorFlow GPT-2 language-model head from the same directory.
# NOTE: this runs at import time and may take a while / use significant RAM.
model = TFGPT2LMHeadModel.from_pretrained(model_dir)

# Create the Flask application instance.
app = Flask(__name__)

# Define the text-generation API endpoint
@app.route('/generate', methods=['POST'])
def generate_text():
    """Generate a GPT-2 continuation for the JSON field ``text``.

    Expects a POST body like ``{"text": "..."}`` and returns JSON
    ``{"generated_text": "..."}``. Responds with HTTP 400 when the body
    is not valid JSON or the ``text`` field is missing/empty (previously
    this raised and surfaced as a 500).
    """
    # get_json(silent=True) returns None for non-JSON bodies instead of
    # raising; guard that and a missing/empty "text" key explicitly.
    payload = request.get_json(silent=True)
    if not payload or not payload.get('text'):
        return jsonify({'error': "missing 'text' field"}), 400
    input_text = payload['text']

    # Encode the prompt into token ids as a TensorFlow tensor.
    input_ids = tokenizer.encode(input_text, return_tensors='tf')

    # Generate up to 100 tokens total (prompt included), one sequence.
    output_ids = model.generate(input_ids, max_length=100, num_return_sequences=1)

    # Decode the generated ids back to text, dropping special tokens.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Return the generated text as JSON.
    return jsonify({'generated_text': output_text})

@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the demo page; on POST, generate text and report latency.

    GET renders the empty form. POST reads ``input_text`` from the form,
    runs GPT-2 generation, and prepends the elapsed time to the result.
    """
    title = "基于 GPT-2 的开源模型搭建(762 million)"
    top = "GPT-2开源模型(762 million)"
    if request.method == 'POST':
        # .get avoids a KeyError (HTTP 500) when the field is absent.
        input_text = request.form.get('input_text', '')
        if not input_text:
            # Nothing to generate from — just re-render the form.
            return render_template('index.html', title_lable=title, top_lable=top)
        start_time = datetime.datetime.now()
        input_ids = tokenizer.encode(input_text, return_tensors='tf')
        output_ids = model.generate(input_ids, max_length=100, num_return_sequences=1)
        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # Bug fix: the message claims seconds ("秒") but previously embedded a
        # raw timedelta repr (e.g. "0:00:05.123456"); report seconds instead.
        elapsed = (datetime.datetime.now() - start_time).total_seconds()
        return render_template(
            'index.html',
            generated_text="(请求耗时: {}秒)".format(round(elapsed, 2)) + output_text,
            title_lable=title,
            top_lable=top,
        )
    return render_template('index.html', title_lable=title, top_lable=top)

# Entry point: start the development server when run as a script.
if __name__ == '__main__':
    host, port = "0.0.0.0", 5000
    app.run(host=host, port=port)
